/*
    Copyright 2006-2012 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/// \file
/// Declares the thermal_equilibrium_grain class and its worker-thread helper.

#ifndef __thermal_equilibrium_grain__
#define __thermal_equilibrium_grain__

#include <iostream>
#include "blitz/array.h"
#include "blackbody.h"
#include "solvert.h"
#include "mcrx-types.h"
#include "mcrx-units.h"
#include "mono_poly_abstract.h"
#include <vector>
#include "misc.h"
#include <string>
#include "threadlocal.h"
#include "interpolatort.h"
#include "vecops.h"
#include "grain.h"
#include "preferences.h"
#include "boost/thread/thread.hpp"
#include "config.h"

#ifdef WITH_CUDA
#include "cuda_grain_temp.h"
#endif

// Forward declaration, so that the class can be defined further down
// using its qualified name mcrx::thermal_equilibrium_grain.
namespace mcrx {
  class thermal_equilibrium_grain;
};

//for testing purposes
// Declared here at global scope so that the class below can name
// ::main as a friend.
int main(int, char**);

/** A class representing a grain species, ie a set of grains of
    different sizes with a specific composition, which is in thermal
    equilibrium. The class knows how to calculate grain temperatures
    and the resulting emission SED. The thermal_equilibrium_grain
    class methods are not reentrant and can not be used by multiple
    threads. You must make a copy of the object for each thread. 

    The math behind how this works: 

    The emitted luminosity we get by integrating the blackbody
    specific intensity over area and solid angle. The emission is
    Lambertian, ie an area element emits as \f[ \mathrm{d}L =
    \mathrm{d}A \cos \theta \mathrm{d}\Omega B_\lambda(\lambda)
    \mathrm{d}\lambda.\f] Integrating \f$\mathrm{d}A\f$ gives the
    total area \f$A=4 \pi R^2\f$, which expressed in cross-section is
    \f$ 4 \sigma(\lambda) \f$. Integrating the \f$ \cos \theta
    \mathrm{d}\Omega \f$ gives \f$\pi\f$. We are left with
    \f[L_\mathrm{emitted} = 4 \pi \! \int \! \sigma(\lambda)
    B_\lambda(\lambda) \mathrm{d}\lambda.\f] (Note that the \f$4 \pi\f$
    has nothing to do with steradians. The 4 is from expressing the
    area in terms of cross section, and \f$\pi\f$ from the solid angle
    integral.) If \f$\sigma\f$ is independent of wavelength, we
    recover the Stefan-Boltzmann law.

    In the same way, the absorbed luminosity is the integral of the
    intensity of the radiation field over area, solid angle and
    wavelength. The integral is completely analogous and comes out to
    be \f[L_\mathrm{absorbed} = 4 \pi \! \int \! \sigma(\lambda)
    I_\lambda(\lambda) \mathrm{d}\lambda.\f]

*/  
class mcrx::thermal_equilibrium_grain : public emitting_grain {
public:
  friend int ::main(int, char**);
  typedef T_float T_numtype;
protected:

  /// Thread object for the CPU calculation.
  template<typename t> class temp_thread;
  /// Number of worker threads created by start_threads().
  int n_threads_;
  /// Forwarded to the worker threads; when set, each of them calls
  /// bind_thread() before starting work.
  bool bind_threads_;
  /// Forwarded to the worker threads, which pass it to bind_thread().
  int bank_size_;

  /// Absorbed luminosity of the grain currently being solved for;
  /// func() adds it to the (negative) emission integral.
  mutable T_float heating_;
  /// Index of the grain size currently being solved for; selects the
  /// row of esigma() used by func() and der().
  mutable int current_grain_;

  /// Forwarded to the CUDA kernel.
  /// NOTE(review): presumably also the solver tolerance -- confirm in
  /// the implementation file.
  T_float accuracy_;
  /// Selects the CUDA code path in calculate_SED_from_intensity_virtual().
  bool use_cuda_;
  /// Makes the worker threads use interpolate_T() instead of solve_for_T().
  bool use_lookup_;

  /// NOTE(review): presumably records whether temp_interpolators_ has
  /// been built -- confirm in the implementation file.
  mutable bool temp_table_updated_;
  typedef interpolator<T_float, T_float, 1> T_interpolator;
  typedef solver<thermal_equilibrium_grain>  T_solver;

  /// Interpolator is lookup table for grain temperatures as a
  /// function of absorbed luminosity.
  mutable std::vector<T_interpolator> temp_interpolators_;

  /// Solver object; it is always constructed from *this, never copied
  /// from another grain object.
  mutable T_solver s_;

  /** Set up internal variables. */
  virtual void setup();
  void setup_interpolators() const;

  /** Calls the mcrxcuda routines to calculate the grain SEDs. Empty
      heating or temp arrays mean that those quantities are not
      requested. */
  template<typename T>
  void CUDA_calculate_SED_from_intensity (const blitz::ETBase<T>& intensity,
                                          const array_1& mdust,
                                          const array_1& dn,
                                          array_2 sed, 
                                          bool add_to_sed,
                                          array_2 heating=array_2(),
                                          array_2 temp=array_2()) const;

  /** Starts the CPU threads for calculating the grain SEDs. The
      heating and temp pointers may be null, in which case those
      quantities are not saved. */
  template<typename T>
  void start_threads(const blitz::ETBase<T>& intensity, 
                     const array_1& m_dust, 
                     const array_1& dn, array_2& sed, 
                     bool add_to_sed,
                     size_t block_size,
                     array_2* heating,
                     array_2* temp) const;

  /** Returns the temperature of grain size s for the specified
      heating. start_threads() relies on a call to this function to
      make sure the interpolators are initialized. */
  T_float interpolate_T(int s, T_float) const;  

  /** Solve for grain temp from the specified initial guess using the
      solver. */
  T_float solve_for_T(int s, T_float heating, T_float guess) const;

  /** Solve for grain temp using our formula for initial guess. The
      guess is a quadratic fit in log10(heating) and log10(size), and
      is never allowed to be below 5. */
  T_float solve_for_T(int s, T_float heating) const {
    const T_float log_size = log10(asizes()(s));
    const T_float xx = log10(heating) - 
      1.39*log_size + 
      0.126*log_size*log_size;
    const T_float Tguess =
      pow(10,1.86+0.189*xx+3.41e-3*xx*xx);
    // Zero heating makes xx -inf and the fit evaluate to NaN
    // (-inf + inf). The test is written as !(>=) so that a NaN guess
    // also falls back to the floor instead of reaching the solver.
    return solve_for_T(s, heating, !(Tguess>=5) ? 5 : Tguess);
  }

  
public:
  /** Constructor takes the name of a file containing cross section
      information and the preferences object. minsize and maxsize
      presumably restrict the range of grain sizes that is used
      (confirm in the implementation file).

      NOTE(review): earlier documentation said the constructor also
      takes the minimum wavelength for which blackbody emission is to
      be considered (strictly for efficiency purposes, no point in
      calculating dust emission at UV wavelengths). There is no such
      parameter; presumably it is now read from the preferences. */
  thermal_equilibrium_grain (const std::string& file_name, 
                             const Preferences& p, T_float minsize=0, 
                             T_float maxsize=blitz::huge(T_float()));

  /** Copy constructor. Members are listed in declaration order, which
      is the order they are initialized in regardless of how the list
      is written. Every member is initialized; in particular use_cuda_
      must be copied, because calculate_SED_from_intensity_virtual()
      branches on it in objects returned by clone(). */
  thermal_equilibrium_grain (const thermal_equilibrium_grain& g) :
    emitting_grain(g),
    n_threads_(g.n_threads_),
    bind_threads_(g.bind_threads_), 
    bank_size_(g.bank_size_),
    // scratch value used while solving, so it is reset rather than
    // copied
    heating_(0),
    // need to make sure all array members are real copies
    //absorption_(make_thread_local_copy(g.absorption_)),
    current_grain_(g.current_grain_),
    accuracy_(g.accuracy_),
    use_cuda_(g.use_cuda_),
    use_lookup_(g.use_lookup_),
    // but the interpolators can be shared because they are
    // reentrant. we need to make a copy of the vector though.
    // actually interpolators should not be shared because it causes
    // poor parallel performance (on altix). They should be allocated
    // locally.
    temp_table_updated_(g.temp_table_updated_),
    temp_interpolators_(g.temp_interpolators_.begin(),
                        g.temp_interpolators_.end()),
    // the solver is bound to this object, not to g
    s_(*this)
  {}

  /** Returns a heap-allocated copy of this object, made with the copy
      constructor. */
  virtual boost::shared_ptr<emitting_grain> 
  clone() const {
    return boost::shared_ptr<emitting_grain>
      (new thermal_equilibrium_grain(*this));
  }

  //why are these public?
  // NOTE(review): presumably because solver<thermal_equilibrium_grain>
  // calls func() and der() from outside the class -- confirm.

  /** Function definition for the solver. Returns the difference
      between the heating and the emitted luminosity at the specified
      temperature for the current grain. */
  T_float func (T_float T) const {
    // T must not be NaN
    assert(T==T);
    const double invT=1./T;
    // using stencil is slow because it duplicates all B_lambda
    // evaluations, so we use a low-level equivalent here.

    T_float sum=0;
    // manually integrate using simple trapezoidal rule

    int i = n_elambda()-1;
    assert(invelambda().size()==n_elambda());
    assert(esigma().extent(blitz::secondDim)==n_elambda());

    // raw pointers to the last element of each array; they are walked
    // backward together in the loop below
    const T_float* plam = elambda().data() + i;
    const T_float* pilam = invelambda().data() + i;
    const T_float* psig = &esigma()(current_grain_, i);
    // the pointer walk over psig requires unit stride in wavelength
    assert(esigma().stride(blitz::secondDim)==1);
    T_float xold=*plam;
    T_float yold=*psig * B_lambda(*pilam, invT);
    for(; i>0; --i) {
      T_float xnew= *(--plam);
      T_float ynew= *(--psig) * B_lambda(*(--pilam), invT);
      // note that we integrate backward and we want to get negative
      // numbers so we add here
      sum += 0.5 * (yold + ynew) * (xnew - xold);
      xold=xnew;
      yold=ynew;
    }    
    const T_float result = 4.0*constants::pi*sum + heating_;

    // in debug mode, check these results against the real integration code
    // NOTE(review): the check below is one-sided (no absolute value),
    // so a large negative relative difference passes.
    DEBUG(1,					\
	  const T_float ref = heating_ -				\
	  4.0*constants::pi*integrate_quantity(esigma()(current_grain_,	\
							blitz::Range::all())* \
					       B_lambda(invelambda(), invT), \
					       elambda(), false);	\
	  assert((result-ref)/ref<1e-5););
    return result;
  }

  /** Derivative definition for the solver. Returns the derivative of
      func() with respect to temperature for the current grain, ie
      minus the temperature derivative of the emitted luminosity. */
  T_float der (T_float T) const {
    // T must not be NaN
    assert(T==T);
    const double invT=1./T;    
    T_float sum=0;
    // manually integrate using simple trapezoidal rule

    int i = n_elambda()-1;
    // raw pointers to the last element of each array; they are walked
    // backward together in the loop below
    const T_float* plam = elambda().data() + i;
    const T_float* pilam = invelambda().data() + i;
    const T_float* psig = &esigma()(current_grain_, i);
    T_float xold=*plam;
    T_float yold=*psig * dB_lambda_dT(*pilam, invT);
    for(; i>0; --i) {
      T_float xnew= *(--plam);
      T_float ynew= *(--psig) * dB_lambda_dT(*(--pilam), invT);
      // note that we integrate backward and we want to get negative
      // numbers so we add here
      sum += 0.5 * (yold + ynew) * (xnew - xold);
      xold=xnew;
      yold=ynew;
    }
    const T_float result = 4.0*constants::pi*sum;

    // in debug mode, check these results against the real integration code
    // NOTE(review): the check below is one-sided (no absolute value),
    // so a large negative relative difference passes.
    DEBUG(1,								\
	  const T_float ref =						\
	  -4.0*constants::pi*integrate_quantity(esigma()(current_grain_, \
							 blitz::Range::all())* \
						dB_lambda_dT(invelambda(), invT), \
						elambda(), false);	\
	  assert((result-ref)/ref<1e-5););
    return result;
  }


  /** This function calculates the emission spectrum of the grains
      based on the specified heating intensity and dust mass in the
      cells and the specified size distribution. The calculation is
      done in parallel or on a GPU, if applicable. For efficiency, the
      result is written to the existing array sed, or if add_to_sed is
      true, added to existing data in the sed array. If temp is
      non-null, the grain temperatures are saved there. */
  template <typename T>
  void calculate_SED_from_intensity_virtual (const blitz::ETBase<T>& intensity,
                                             const array_1& mdust,
                                             const array_1& dn,
                                             array_2 sed, 
                                             bool add_to_sed,
                                             array_2* temp) const;
}; // thermal equilibrium grain


/** Thread object, executes the temperature calculation. Because the
    intensity is being passed as an expression template, it needs to
    be templated on that type. This means you can only send one type
    of intensity to one type of thread, but that shouldn't matter
    because they are "use-once" type of objects. */
template<typename T>
class mcrx::thermal_equilibrium_grain::temp_thread {
public:
  static const int cache_line_size = 128;

  /// To ensure threads don't share cache line.
  char padding [cache_line_size]; 
  /// If set, temperatures come from interpolate_T(), otherwise from
  /// solve_for_T().
  bool use_lookup_;
  /// If not set, run_range() zeroes the sed rows it is about to fill.
  bool add_to_sed_;
  /// If set, operator() calls bind_thread() before doing any work.
  bool bind_thread_;
  /// Passed to bind_thread() together with thread_number_.
  int bank_size_;

  int thread_number_;

  /// Private copy of the grain object, because its methods are not
  /// reentrant.
  thermal_equilibrium_grain g_;

  // each thread needs its own copy of the expression
  T intensity_;
  /// Weak reference to the size distribution supplied by the caller.
  array_1 dnd_;
  /// Weak reference to the dust masses supplied by the caller.
  array_1 m_dust_;
  
  /// Optional output for the heating, indexed (size, cell). May be null.
  array_2* heating_;
  /// Optional output for the temperature, indexed (size, cell). May
  /// be null.
  array_2* temp_;
  /// Weak reference to the output SED array supplied by the caller.
  array_2 sed_;
  /// Scratch buffer for the emission spectrum of one grain in one
  /// cell.
  array_1 tempsed_;

  /// list of cell blocks to be processed
  std::vector<std::pair<size_t,size_t> >& clist_;

  /// Reference to the mutex protecting the cell vector
  boost::mutex& cell_mutex_;


public:
  /** Constructor. Copies the grain object and the intensity
      expression, and makes weak references to the arrays d, md and
      s. The arguments ats, ul, bt and bs initialize add_to_sed_,
      use_lookup_, bind_thread_ and bank_size_, in that order. The
      initializers are written in member declaration order, which is
      the order they execute in. */
  temp_thread(int tn,
              const thermal_equilibrium_grain& g,
              const blitz::ETBase<T>& i, const array_1& d,
              const array_1& md,
              array_2* h,
              array_2* t,
              array_2& s,
              std::vector<std::pair<size_t,size_t> >& cl,
              boost::mutex& cm,
              bool ats, bool ul, bool bt, int bs) :
    use_lookup_(ul), add_to_sed_(ats),
    bind_thread_(bt), bank_size_(bs),
    thread_number_(tn), g_(g),
    intensity_(i.unwrap()),
    heating_(h),
    temp_(t),
    clist_(cl), cell_mutex_(cm)
  {
    dnd_.weakReference(d);
    m_dust_.weakReference(md);
    sed_.weakReference(s);
    tempsed_.resize(g_.esigma().extent(blitz::secondDim));
  }

  /** Copy constructor. The grain object is copied with its copy
      constructor, so each thread object has its own. The array
      members are copied with the array copy constructor, which for
      blitz arrays refers to the same data as the source; that is
      intended for the caller's input and output arrays, but the
      scratch buffer tempsed_ is allocated afresh so that a copy never
      shares it with the object it was copied from. Its contents are
      always overwritten before they are read. */
  temp_thread(const temp_thread& t) :
    use_lookup_(t.use_lookup_), add_to_sed_(t.add_to_sed_), 
    bind_thread_(t.bind_thread_), bank_size_(t.bank_size_),
    thread_number_(t.thread_number_), g_(t.g_),
    intensity_(t.intensity_), dnd_(t.dnd_), 
    m_dust_(t.m_dust_), heating_(t.heating_), temp_(t.temp_), sed_(t.sed_),
    tempsed_(t.tempsed_.extent(blitz::firstDim)),
    clist_(t.clist_), cell_mutex_(t.cell_mutex_)
  {}

  void operator () ();
  void run_range(size_t cmin, size_t cmax);
}; // temp thread


/** Calculates the emission SED of the grains, either with the CUDA
    routines (if use_cuda_ is set) or with the CPU worker threads. The
    result is written to sed, or added to it if add_to_sed is
    true. temp may be null, in which case the grain temperatures are
    not saved. */
template <typename T>
void
mcrx::thermal_equilibrium_grain::
calculate_SED_from_intensity_virtual(const blitz::ETBase<T>& intensity,
                                     const array_1& m_dust,
                                     const array_1& dn,
                                     array_2 sed, 
                                     bool add_to_sed,
                                     array_2* temp) const
{
  if(use_cuda_) {
    // The CUDA routine takes its optional outputs by value and treats
    // an empty array as "not requested", so a null temp pointer must
    // be translated into an empty array instead of being
    // dereferenced.
    CUDA_calculate_SED_from_intensity(intensity, m_dust, dn, sed, add_to_sed,
                                      array_2(),
                                      temp ? *temp : array_2());
  }
  else {
    const size_t block_size=32;
    
    // the heating is not requested (null pointer)
    start_threads(intensity, m_dust, dn, sed, add_to_sed, 
                  block_size, 0, temp);
  }
}


/** Thread entry point. Optionally binds the thread, then keeps
    taking cell blocks off the shared list and processing them until
    the list is empty. */
template<typename T>
void mcrx::thermal_equilibrium_grain::temp_thread<T>::operator()()
{
  if(bind_thread_)
    bind_thread(thread_number_, bank_size_);

  for (;;) {
    std::pair<size_t,size_t> block;
    bool got_block = false;
    {
      // the lock is only held while the shared list is touched
      boost::mutex::scoped_lock stack_lock (cell_mutex_);
      if (!clist_.empty()) {
        block = clist_.back();
        clist_.pop_back();
        got_block = true;
      }
    }

    // nothing left to do
    if (!got_block)
      break;

    DEBUG(1,std::cout << "\tRunning cells " << block.first << "-" << block.second << std::endl;);
    run_range(block.first, block.second);
  }
  DEBUG(1,std::cout << "\tThread done" << std::endl;);
}


/** Calculates the SED for the range of cells specified. Note that the
    range is inclusive, ie cell cmax is also calculated. For every
    grain size and every cell with nonzero dust mass, the heating is
    calculated, the grain temperature is obtained from it (by lookup
    or by solving), and the resulting emission is added to the sed row
    of the cell, rescaled so that the emitted luminosity equals the
    heating. */
template<typename T>
void 
mcrx::thermal_equilibrium_grain::temp_thread<T>::
run_range(size_t cmin, size_t cmax)
{
  using namespace blitz;

  const int n_sizes=g_.sigma().extent(firstDim);

  // unless we are accumulating, clear the rows we are about to fill
  if (!add_to_sed_)
    sed_(Range(cmin,cmax),Range::all())=0;

  // the loops over cells and sizes are blocked in chunks of this many
  // elements
  const size_t size_chunk = 1000;
  const size_t cell_chunk = 1000;

  for(int cell_lo=cmin; cell_lo<=cmax; cell_lo+=cell_chunk) {
    // one past the last cell of this chunk
    int cell_hi=cell_lo+cell_chunk;
    if (cell_hi>cmax+1)
      cell_hi=cmax+1;

    for(int size_lo=0; size_lo<n_sizes; size_lo+=size_chunk) {
      // one past the last size of this chunk
      int size_hi=size_lo+size_chunk;
      if (size_hi>n_sizes)
        size_hi=n_sizes;

      for(int s=size_lo; s<size_hi; ++s) {
        for(int c=cell_lo; c<cell_hi; ++c) {
          // cells without dust don't emit
          if(m_dust_(c)==0)
            continue;

          const T_float absorbed = 
            g_.calculate_heating(s, intensity_(c,Range::all()));
          const T_float grain_T = use_lookup_ ?
            g_.interpolate_T(s, absorbed) :
            g_.solve_for_T(s, absorbed);

          assert(grain_T>=0);
          assert(grain_T<1e4);

          // emission spectrum of this grain at that temperature
          tempsed_ = g_.esigma()(s, Range::all())*
            mcrx::B_lambda(g_.invelambda(), 1./grain_T)*
            4*constants::pi;
          const T_float lum=integrate_quantity(tempsed_, g_.elambda(), false);
          if(lum>0) {
            if(abs(lum/absorbed-1)>1e-2)
              cout << "Warning: poor energy conservation in grain: " << lum/absorbed-1 << '\t' << s << ',' << c <<endl;

            // the factor absorbed/lum makes the emitted luminosity
            // match the heating exactly
            sed_(c,Range::all()) += tempsed_*m_dust_(c)*dnd_(s)*absorbed/lum;
          }

          // the optional outputs are only filled in if they were
          // supplied
          if(heating_) 
            (*heating_)(s,c) = absorbed;
          if(temp_) 
            (*temp_)(s,c) = grain_T; 
        }
      }
    }
  }
}




/** Splits the cells into blocks and processes them with worker
    threads. If block_size is zero, the cells are divided evenly among
    the threads. At least one worker is always used, even if
    n_threads_ is not positive, so that the calculation is never
    silently skipped. The function returns when all workers have
    finished. heating and temp may be null, in which case those
    quantities are not saved. */
template<typename T>
void mcrx::thermal_equilibrium_grain::
start_threads(const blitz::ETBase<T>& intensity, const array_1& m_dust, 
              const array_1& dn, array_2& sed, 
              bool add_to_sed,
              size_t block_size,
              array_2* heating,
              array_2* temp) const
{
  using namespace blitz;
  using namespace std;

  const size_t nc = intensity.unwrap().ubound(firstDim)-intensity.unwrap().lbound(firstDim)+1;

  // A nonpositive thread count would make the block size calculation
  // below divide by zero (or go negative), and would leave the cells
  // unprocessed because no worker is spawned.
  const int n_workers = (n_threads_ > 0) ? n_threads_ : 1;

  if(use_lookup_)
    // make sure interpolators are initialized
    interpolate_T(0, 0.0);

  if(block_size==0)
    block_size=size_t(ceil(1.0*nc/n_workers));
  // the loop below would never terminate with a zero step
  if(block_size==0)
    block_size=1;
  DEBUG(1,cout << "Block size is " << block_size << endl;);

  // make list of (first, last) cell pairs, both inclusive
  vector<pair<size_t, size_t> > clist;
  for (size_t c=0; c<nc; c+=block_size) {
    size_t cend= c+block_size-1;
    if (cend>=nc)
      cend=nc-1;
    clist.push_back(make_pair(c,cend));
  }

  // create threads
  boost::thread_group threads;
  std::vector<boost::shared_ptr<temp_thread<T> > > thread_objects;
  thread_objects.reserve(n_workers);
  // protects clist, which all workers take their blocks from
  boost::mutex cell_mutex;
  for (int i = 0; i < n_workers; ++i) {
    thread_objects.push_back(boost::shared_ptr<temp_thread<T> > 
                             (new temp_thread<T>
                              (i, *this, intensity, dn, m_dust,
                               heating, temp, sed,
                               clist, cell_mutex,
                               add_to_sed, use_lookup_, 
                               bind_threads_, bank_size_)));
  }
  cout << " Spawning " << n_workers << " threads" << endl;
  for (int i = 0; i < n_workers; ++i) {
    threads.create_thread(*thread_objects[i]);
  }
  // wait for them to die
  threads.join_all();
  cout << " Threads done." << endl;
}


/** Calculates the grain emission SED with the mcrxcuda routines. The
    intensity expression is evaluated into a contiguous array, the
    size distribution is renormalized, and everything is handed to the
    kernel as raw pointers. heating and temp are optional outputs;
    empty arrays mean they are not requested. If the code was compiled
    without WITH_CUDA, calling this function is an error. */
template<typename T>
void
mcrx::thermal_equilibrium_grain::
CUDA_calculate_SED_from_intensity(const blitz::ETBase<T>& intensity,
                                  const array_1& m_dust,
                                  const array_1& dn,
                                  array_2 sed, 
                                  bool add_to_sed,
                                  array_2 heating,
                                  array_2 temp) const
{
  using namespace std;
  using namespace blitz;

  cout << "Calculating thermal equilibrium grain emission" << endl;

#ifdef WITH_CUDA
  // bite the bullet and allocate an array for the intensities,
  // because we can't pass the expression to the code compiled with
  // nvcc because nvcc can't handle blitz.

  array_2 intensity_array(shape(intensity.unwrap().ubound(firstDim)-intensity.unwrap().lbound(firstDim)+1,
                                intensity.unwrap().ubound(secondDim)-intensity.unwrap().lbound(secondDim)+1),
                          contiguousArray);
  intensity_array = intensity.unwrap();

  const size_t nc = intensity_array.extent(firstDim);

  // If the mass unit of the grain_size distribution is Msun, it's
  // very likely that dn overflows a float. To get around that, we
  // renormalize it and supply that
  // normalization to the kernel.
  T_float sed_norm = mean(dn);
  if(sed_norm>1e-2*blitz::huge(float()))
    sed_norm=1e-2*blitz::huge(float());
  // If the size distribution is all zeros, the mean is zero and the
  // division below would produce NaNs (0/0). Any nonzero
  // normalization works in that case, because dn itself is zero.
  if(!(sed_norm>0))
    sed_norm=1;
  array_1 renorm_dn(dn/sed_norm);
  assert(max(renorm_dn)<blitz::huge(float()));

  // some sanity checks, since the kernel gets raw pointers and
  // depends on the memory layout of the arrays
  assert(sigma().isStorageContiguous());
  assert(sigma().isMinorRank(secondDim));
  assert(sigma().extent(firstDim)==dn.size());
  assert(sigma().extent(secondDim)==alambda().size());
  assert(sigma().isRankStoredAscending(firstDim));
  assert(sigma().isRankStoredAscending(secondDim));
  assert(esigma().isStorageContiguous());
  assert(esigma().isMinorRank(secondDim));
  assert(esigma().extent(firstDim)==dn.size());
  assert(esigma().extent(secondDim)==elambda().size());
  assert(esigma().isRankStoredAscending(firstDim));
  assert(esigma().isRankStoredAscending(secondDim));
  assert(intensity_array.isStorageContiguous());
  assert(intensity_array.isMinorRank(secondDim));
  assert(alambda().isStorageContiguous());
  assert(elambda().isStorageContiguous());
  assert(asizes().isStorageContiguous());
  assert(m_dust.isStorageContiguous());
  assert(renorm_dn.isStorageContiguous());
  assert(renorm_dn.size()==n_size());
  assert(sed.isStorageContiguous());
  assert(sed.isMinorRank(secondDim));
  assert(sed.extent(firstDim)==nc);
  assert(sed.extent(secondDim)==n_elambda());

  // single-precision buffers for the optional outputs; they are empty
  // if the corresponding output was not requested
  Array<float, 2> fheating(heating.shape());
  Array<float, 2> ftemp(temp.shape());
  
  mcrxcuda::calculate_equilibrium_SED
    (sigma().dataFirst(), esigma().dataFirst(),
     intensity_array.dataFirst(), alambda().dataFirst(), elambda().dataFirst(), 
     asizes().dataFirst(), m_dust.dataFirst(), renorm_dn.dataFirst(), 
     sed.dataFirst(),
     n_size(), nc, n_lambda(), n_elambda(), add_to_sed, accuracy_,
     sed_norm, 
     (fheating.size()>0) ? fheating.dataFirst() : 0, 
     (ftemp.size()>0) ? ftemp.dataFirst() : 0);

  // This will only do something if the arrays have size, so it's fine
  heating=cast<T_float>(fheating);
  temp=cast<T_float>(ftemp);
#else
  // NOTE(review): when assertions are compiled out this returns
  // without calculating anything.
  assert(0);
#endif
}




#endif
